from __future__ import annotations

import os
from datetime import datetime, timezone

import pytest
from fastapi.testclient import TestClient

from tldw_Server_API.app.main import app


def _setup_env():
    os.environ["AUTH_MODE"] = "single_user"
    os.environ["SINGLE_USER_API_KEY"] = "unit-test-api-key-bytes-in"
    os.environ["USAGE_LOG_ENABLED"] = "true"
    # Exclude all paths to avoid middleware logging interfering with aggregation
    os.environ["USAGE_LOG_EXCLUDE_PREFIXES"] = "[\"/\"]"


async def _prepare_tables_and_seed_bytes_in():
    from tldw_Server_API.app.core.AuthNZ.database import get_db_pool
    pool = await get_db_pool()

    # Ensure two users exist
    import uuid as _uuid
    await pool.execute(
        "INSERT OR IGNORE INTO users (uuid, username, email, password_hash, is_active) VALUES (?,?,?,?,1)",
        str(_uuid.uuid4()), "u_bytes1", "u_bytes1@example.com", "x"
    )
    await pool.execute(
        "INSERT OR IGNORE INTO users (uuid, username, email, password_hash, is_active) VALUES (?,?,?,?,1)",
        str(_uuid.uuid4()), "u_bytes2", "u_bytes2@example.com", "x"
    )
    u1 = await pool.fetchval("SELECT id FROM users WHERE username = ?", "u_bytes1")
    u2 = await pool.fetchval("SELECT id FROM users WHERE username = ?", "u_bytes2")

    # Create usage_log with bytes_in column and usage_daily with bytes_in_total
    await pool.execute(
        """
        CREATE TABLE IF NOT EXISTS usage_log (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            ts TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
            user_id INTEGER,
            key_id INTEGER,
            endpoint TEXT,
            status INTEGER,
            latency_ms INTEGER,
            bytes INTEGER,
            bytes_in INTEGER,
            meta TEXT,
            request_id TEXT
        )
        """
    )
    usage_log_cols = {row["name"] for row in await pool.fetchall("PRAGMA table_info(usage_log)")}
    if "bytes_in" not in usage_log_cols:
        await pool.execute("ALTER TABLE usage_log ADD COLUMN bytes_in INTEGER")
    if "request_id" not in usage_log_cols:
        await pool.execute("ALTER TABLE usage_log ADD COLUMN request_id TEXT")

    await pool.execute(
        """
        CREATE TABLE IF NOT EXISTS usage_daily (
            user_id INTEGER NOT NULL,
            day DATE NOT NULL,
            requests INTEGER DEFAULT 0,
            errors INTEGER DEFAULT 0,
            bytes_total INTEGER DEFAULT 0,
            bytes_in_total INTEGER DEFAULT 0,
            latency_avg_ms REAL,
            PRIMARY KEY (user_id, day)
        )
        """
    )
    usage_daily_cols = {row["name"] for row in await pool.fetchall("PRAGMA table_info(usage_daily)")}
    if "bytes_in_total" not in usage_daily_cols:
        await pool.execute("ALTER TABLE usage_daily ADD COLUMN bytes_in_total INTEGER DEFAULT 0")

    # Seed usage rows for today for each user with distinct inbound/outbound bytes
    await pool.execute(
        "INSERT INTO usage_log (user_id, endpoint, status, latency_ms, bytes, bytes_in, meta) VALUES (?,?,?,?,?,?,?)",
        int(u1), "/api/test", 200, 100, 1000, 400, None
    )
    await pool.execute(
        "INSERT INTO usage_log (user_id, endpoint, status, latency_ms, bytes, bytes_in, meta) VALUES (?,?,?,?,?,?,?)",
        int(u2), "/api/test", 200, 80, 500, 900, None
    )

    return int(u1), int(u2)


@pytest.mark.asyncio
async def test_usage_bytes_in_aggregation_and_exports(monkeypatch):
    _setup_env()

    # Reset singletons for env take effect
    from tldw_Server_API.app.core.AuthNZ.database import reset_db_pool
    from tldw_Server_API.app.core.AuthNZ.settings import reset_settings
    from tldw_Server_API.app.core.AuthNZ.session_manager import reset_session_manager

    await reset_db_pool()
    reset_settings()
    await reset_session_manager()

    headers = {"X-API-KEY": os.environ["SINGLE_USER_API_KEY"]}
    with TestClient(app, headers=headers) as client:
        u1, u2 = await _prepare_tables_and_seed_bytes_in()
        day_str = datetime.now(timezone.utc).date().isoformat()

        # Aggregate today
        r_agg = client.post(f"/api/v1/admin/usage/aggregate?day={day_str}")
        assert r_agg.status_code == 200, r_agg.text

        # Daily should include bytes_in_total
        r_daily = client.get(f"/api/v1/admin/usage/daily?start={day_str}&end={day_str}")
        assert r_daily.status_code == 200
        daily = r_daily.json()
        assert isinstance(daily.get("items"), list)
        # Map by user
        by_user = {int(row["user_id"]): row for row in daily["items"]}
        assert by_user[u1]["bytes_in_total"] in (400, 0, None) or isinstance(by_user[u1]["bytes_in_total"], int)
        assert by_user[u2]["bytes_in_total"] in (900, 0, None) or isinstance(by_user[u2]["bytes_in_total"], int)

        # CSV export should include bytes_in_total column
        r_csv = client.get(f"/api/v1/admin/usage/daily/export.csv?start={day_str}&end={day_str}&limit=10")
        assert r_csv.status_code == 200
        assert "bytes_in_total" in r_csv.text.splitlines()[0]

        # Top users by inbound bytes should list u2 first (900 > 400) when column supported
        r_top_in = client.get(f"/api/v1/admin/usage/top?start={day_str}&end={day_str}&metric=bytes_in_total&limit=2")
        assert r_top_in.status_code == 200
        top_items = r_top_in.json().get("items", [])
        assert isinstance(top_items, list)
        if top_items and len(top_items) >= 1:
            # If bytes_in_total is supported in this DB, first should be u2; else legacy fallback won't include it
            # We still assert the endpoint is functional.
            first_user_id = int(top_items[0]["user_id"])
            assert first_user_id in {u1, u2}
